{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Week 3 homework\n",
    "Use an XGBoost model to complete the classification task (parameter tuning required)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import xgboost for its native API\n",
    "import xgboost as xgb\n",
    "# Import the sklearn-style wrapper XGBClassifier\n",
    "from xgboost import XGBClassifier\n",
    "# General-purpose utilities\n",
    "import pandas as pd \n",
    "import numpy as np\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "from sklearn.metrics import log_loss\n",
    "\n",
    "from matplotlib import pyplot\n",
    "import seaborn as sns\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data: dtrain and dtest are DMatrix inputs for the native xgboost API\n",
    "dpath = './data/'\n",
    "dtrain = xgb.DMatrix(dpath + 'RentListingInquries_FE_train.bin')\n",
    "dtest = xgb.DMatrix(dpath + 'RentListingInquries_FE_test.bin')\n",
    "# CSV feature matrices for the sklearn-style XGBClassifier\n",
    "train = pd.read_csv(dpath +\"RentListingInquries_FE_train.csv\")\n",
    "test = pd.read_csv(dpath +\"RentListingInquries_FE_test.csv\")\n",
    "train = train.drop([\"interest_level\"], axis=1)\n",
    "X_train = np.array(train)\n",
    "X_test = np.array(test)\n",
    "y_train = dtrain.get_label()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Feature engineering is already done; take a quick look at the data structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "227"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dtrain.num_col()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "49352"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dtrain.num_row()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "227"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dtest.num_col()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "74659"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dtest.num_row() "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "dtest actually has more rows than dtrain (74,659 vs. 49,352), and unlike dtrain it carries no label column.\n",
    "\n",
    "Next, check whether the class distribution is balanced:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZUAAAELCAYAAAARNxsIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAG71JREFUeJzt3X2QHXWd7/H3J+FBfEyQgcomZIOaXUFdo44QZWsX0QsB7xqkQKFciSx1I1ZQKL2W4N1LVh5213WVK7vK3mwRCa5LQESJGoxZBLkgTwHDQ0SKEVkZyYVgwpNcoQif+0f/Bk4mZ2aaoc+cHOfzquqa7m//us/3cISv3b9f/1q2iYiIaMKUbicQERG/P1JUIiKiMSkqERHRmBSViIhoTIpKREQ0JkUlIiIak6ISERGNSVGJiIjGpKhERERjdup2AhNtjz328Jw5c7qdRkRET7nlllsett03VrtJV1TmzJnDunXrup1GRERPkfSfddrl9ldERDSmY0VF0ksk3STpNkkbJH2uxC+Q9EtJ68syr8Ql6VxJA5Jul/TWlnMtknRPWRa1xN8m6Y5yzLmS1KnvExERY+vk7a+ngINtPyFpZ+BaSVeUfZ+2femw9ocBc8tyAHAecICk3YGlQD9g4BZJq2xvKW0WAzcAq4EFwBVERERXdOxKxZUnyubOZRltnv2FwIXluBuAaZJmAIcCa21vLoVkLbCg7Hul7etdzd9/IXBEp75PRESMraN9KpKmSloPPERVGG4su84ut7jOkbRric0E7m85fLDERosPtom3y2OxpHWS1m3atOlFf6+IiGivo0XF9lbb84BZwP6S3gicBrweeDuwO/CZ0rxdf4jHEW+XxzLb/bb7+/rGHBEXERHjNCGjv2w/AlwNLLC9sdziegr4GrB/aTYI7N1y2CzggTHis9rEIyKiSzo5+qtP0rSyvhvwHuDnpS+EMlLrCODOcsgq4LgyCmw+8KjtjcAa4BBJ0yVNBw4B1pR9j0uaX851HHB5p75PRESMrZOjv2YAKyRNpSpel9j+nqQfSeqjun21HjixtF8NHA4MAE8CxwPY3izpTODm0u4M25vL+seAC4DdqEZ9ZeRXREQXqRo4NXn09/c7T9RH7NgO/KcDu53C773rPn7dC2ov6Rbb/WO1yxP1ERHRmBSViIhoTIpKREQ0JkUlIiIak6ISERGNSVGJiIjGpKhERERjUlQiIqIxKSoREdGYFJWIiGhMikpERDQmRSUiIhqTohIREY1JUYmIiMakqERERGNSVCIiojEpKhER0ZgUlYiIaEyKSkRENCZFJSIiGpOiEhERjelYUZH0Ekk3SbpN0gZJnyvxfSTdKOkeSRdL2qXEdy3bA2X/nJZznVbid0s6tCW+oMQGJJ3aqe8SERH1dPJK5SngYNtvBuYBCyTNBz4PnGN7LrAFOKG0PwHYYvt1wDmlHZL2A44B3gAsAL4qaaqkqcBXgMOA/YBjS9uIiOiSjhUVV54omzuXxcDBwKUlvgI4oqwvLNuU/e+WpBJfafsp278EBoD9yzJg+17bTwMrS9uIiOiSjvaplCuK9cBDwFrgF8Ajtp8pTQaBmWV9JnA/QNn/KPDq1viwY0aKR0REl3S0qNjeanseMIvqymLfds3KX42w74XGtyNpsaR1ktZt2rRp7MQjImJcJmT0l+1HgKuB+cA0STuVXbOAB8r6ILA3QNn/KmBza3zYMSPF233+Mtv9tvv7+vqa+EoREdFGJ0d/9UmaVtZ3A94D3AVcBRxVmi0CLi/rq8o2Zf+PbLvEjymjw/YB5gI3ATcDc8tosl2oOvNXder7RETE2HYau8m4zQBWlFFaU4BLbH9P0s+AlZLOAn4KnF/anw98XdIA1RXKMQC2N0i6BPgZ8AywxPZWAEknAWuAqcBy2xs6+H0iImIMHSsqtm8H3tImfi9V/8rw+O+Ao0c419nA2W3iq4HV
LzrZiIhoRJ6oj4iIxqSoREREY1JUIiKiMSkqERHRmBSViIhoTIpKREQ0JkUlIiIak6ISERGNSVGJiIjGpKhERERjxiwqkl4maUpZ/yNJ75O0c+dTi4iIXlPnSuUa4CWSZgJXAscDF3QyqYiI6E11iopsPwkcCfyT7fdTvRM+IiJiG7WKiqR3AB8Cvl9inZwyPyIielSdonIKcBrw7fJuk9dQvWgrIiJiG2Necdj+MfBjSS8r2/cCn+h0YhER0XvqjP56R3lb411l+82SvtrxzCIioufUuf31v4BDgd8A2L4N+LNOJhUREb2p1sOPtu8fFtragVwiIqLH1RnFdb+kdwKWtAtVf8pdnU0rIiJ6UZ0rlROBJcBMYBCYV7YjIiK2MWZRsf2w7Q/Z3sv2nrb/0vZvxjpO0t6SrpJ0l6QNkk4u8b+R9GtJ68tyeMsxp0kakHS3pENb4gtKbEDSqS3xfSTdKOkeSReXK6mIiOiSOqO/Vkia1rI9XdLyGud+BviU7X2B+cASSUNP4p9je15ZVpfz7gccA7wBWAB8VdJUSVOBrwCHUT3Jf2zLeT5fzjUX2AKcUCOviIjokDq3v/7E9iNDG7a3AG8Z6yDbG23fWtYfp+qHmTnKIQuBlbafsv1LYADYvywDtu+1/TSwElgoScDBwKXl+BXAETW+T0REdEidojJF0vShDUm78wKnaZE0h6oQ3VhCJ0m6XdLylnPPBFpHmQ2W2EjxVwOP2H5mWDwiIrqkTlH5IvATSWdKOhP4CfAPdT9A0suBbwGn2H4MOA94LVWH/8ZyfgC1OdzjiLfLYbGkdZLWbdq0qW7qERHxAtXpqL8QOAp4EHgIONL21+ucvLx35VvAN2xfVs73oO2ttp8F/pXq9hZUVxp7txw+C3hglPjDwDRJOw2Lt/sOy2z32+7v6+urk3pERIxD3Tc//hy4DLgceELS7LEOKH0e5wN32f5SS3xGS7P3A3eW9VXAMZJ2lbQPMBe4CbgZmFtGeu1C1Zm/yrapJrY8qhy/qOQXERFdMmbfiKSPA0uprlS2Ut12MvAnYxx6IPBh4A5J60vss1Sjt+aVc9wHfBSgzIB8CfAzqpFjS2xvLTmcBKwBpgLLbW8o5/sMsFLSWcBPqYpYRER0SZ0O95OBP67zbEor29fSvt9j9SjHnA2c3Sa+ut1xZcbk/YfHIyKiO+rc/rofeLTTiURERO+rc6VyL3C1pO8DTw0FW/tJIiIioF5R+VVZdilLREREW3Xe/Pg5AEkvs/3bzqcUERG9Km9+jIiIxuTNjxER0Zi8+TEiIhqTNz9GRERj8ubHiIhozKhXKuUFWR+2/aEJyiciInrYqFcqZe6thROUS0RE9Lg6fSrXSfpn4GLguedUht7qGBERMaROUXln+XtGS8xUr/KNiIh4zlh9KlOA82xfMkH5REREDxurT+VZ4KQJyiUiInpcnSHFayX9d0l7S9p9aOl4ZhER0XPq9Kn8Vfnb+myKgdc0n05ERPSyOrMU7zMRiURERO+r847649rFbV/YfDoREdHL6tz+envL+kuAdwO3AikqERGxjTq3vz7eui3pVcDXO5ZRRET0rFpT3w/zJDB3rEZltNhVku6StEHSySW+u6S1ku4pf6eXuCSdK2lA0u2S3tpyrkWl/T2SFrXE3ybpjnLMuZI0ju8TERENqfPmx+9KWlWW7wF3A5fXOPczwKds7wvMB5ZI2g84FbjS9lzgyrINcBhVsZoLLAbOK5+/O7AUOADYH1g6VIhKm8Utxy2okVdERHRInT6Vf2xZfwb4T9uDYx1keyOwsaw/LukuqunzFwIHlWYrgKuBz5T4hbYN3CBpmqQZpe1a25sBJK0FFki6Gnil7etL/ELgCOCKGt8pIiI6oE5R+RWw0fbvACTtJmmO7fvqfoikOcBbgBuBvUrBwfZGSXuWZjOB1jdMDpbYaPHBNvGIiOiSOn0q3wSebdneWmK1SHo58C3gFNuPjda0TczjiLfLYbGk
dZLWbdq0aayUIyJinOoUlZ1sPz20UdZ3qXNySTtTFZRv2L6shB8st7Uofx8q8UFg75bDZwEPjBGf1Sa+HdvLbPfb7u/r66uTekREjEOdorJJ0vuGNiQtBB4e66AyEut84C7bX2rZtQoYGsG1iOc7/VcBx5VRYPOBR8ttsjXAIZKmlw76Q4A1Zd/jkuaXzzqOegMIIiKiQ+r0qZwIfKO8qAuqK4S2T9kPcyDwYeAOSetL7LPA3wOXSDqBqr/m6LJvNXA4MEA1bPl4ANubJZ0J3FzanTHUaQ98DLgA2I2qgz6d9BERXVTn4cdfAPNL34hsP17nxLavpX2/B1RP5Q9vb7adtLJ133JgeZv4OuCNdfKJiIjOq/Ocyt9Kmmb7iTI0eLqksyYiuYiI6C11+lQOs/3I0IbtLVS3qSIiIrZRp6hMlbTr0Iak3YBdR2kfERGTVJ2O+n8DrpT0NarnQP6K6kn4iIiIbdTpqP8HSbcD7ymhM22v6WxaERHRi+pcqQD8FNiZ6krlp51LJyIielmd0V8fAG4CjgI+ANwo6ahOJxYREb2nzpXK/wDebvshAEl9wH8Al3YysYiI6D11Rn9NGSooxW9qHhcREZNMnSuVH0haA1xUtj9INaVKRETENuqM/vq0pCOBP6WadmWZ7W93PLOIiOg5tUZ/lWnrLxuzYURETGrpG4mIiMakqERERGNGLCqSrix/Pz9x6URERC8brU9lhqQ/B94naSXD3o1i+9aOZhYRET1ntKJyOnAq1bvfvzRsn4GDO5VURET0phGLiu1LgUsl/U/bZ05gThER0aPqPKdypqT3AX9WQlfb/l5n04qIiF5UZ0LJvwNOBn5WlpNLLCIiYht1Hn58LzDP9rMAklZQTX9/WicTi4iI3lP3OZVpLeuv6kQiERHR++oUlb8DfirpgnKVcgvwt2MdJGm5pIck3dkS+xtJv5a0viyHt+w7TdKApLslHdoSX1BiA5JObYnvI+lGSfdIuljSLnW/dEREdMaYRcX2RcB8qrm/LgPeYXtljXNfACxoEz/H9ryyrAaQtB9wDPCGcsxXJU2VNBX4CnAYsB9wbGkL8PlyrrnAFuCEGjlFREQH1br9ZXuj7VW2L7f9f2secw2wuWYeC4GVtp+y/UtgANi/LAO277X9NLASWChJVM/JDL0obAVwRM3PioiIDunG3F8nSbq93B6bXmIzgftb2gyW2EjxVwOP2H5mWLwtSYslrZO0btOmTU19j4iIGGaii8p5wGuBecBG4IslrjZtPY54W7aX2e633d/X1/fCMo6IiNpGLSqSprR2tL9Yth+0vbUMT/5XqttbUF1p7N3SdBbwwCjxh4FpknYaFo+IiC4ataiU//jfJml2Ex8maUbL5vuBoYK1CjhG0q6S9gHmAjcBNwNzy0ivXag681fZNnAVcFQ5fhFweRM5RkTE+NV5+HEGsEHSTcBvh4K23zfaQZIuAg4C9pA0CCwFDpI0j+pW1X3AR8u5Nki6hOqJ/WeAJba3lvOcBKwBpgLLbW8oH/EZYKWks6gexjy/zheOiIjOqVNUPjeeE9s+tk14xP/w2z4bOLtNfDWwuk38Xp6/fRYRETuAOhNK/ljSHwJzbf+HpJdSXTVERERso86Ekv+N6nmQ/11CM4HvdDKpiIjoTXWGFC8BDgQeA7B9D7BnJ5OKiIjeVKeoPFWeZgegDOMd8ZmQiIiYvOoUlR9L+iywm6T/AnwT+G5n04qIiF5Up6icCmwC7qAaArwa+OtOJhUREb2pzuivZ8uU9zdS3fa6uzx8GBERsY0xi4qk9wL/AvyCas6tfSR91PYVnU4uIiJ6S52HH78IvMv2AICk1wLfB1JUIiJiG3X6VB4aKijFvcBDHconIiJ62IhXKpKOLKsbJK0GLqHqUzmaaqLHiIiIbYx2++svWtYfBP68rG8Cpm/fPCIiJrsRi4rt4ycykYiI6H11Rn/tA3wcmNPafqyp7yMiYvKpM/rrO1RT1n8XeLaz
6URERC+rU1R+Z/vcjmcSERE9r05R+bKkpcAPgaeGgrZv7VhWERHRk+oUlTcBHwYO5vnbXy7bETukX53xpm6nMCnMPv2ObqcQO5g6ReX9wGtap7+PiIhop84T9bcB0zqdSERE9L46Vyp7AT+XdDPb9qlkSHFERGyjTlFZOp4TS1oO/FequcPeWGK7AxdTPfNyH/AB21skCfgycDjwJPCRoYEAkhbx/PtbzrK9osTfBlwA7Eb1jpeTMyV/RER3jXn7y/aP2y01zn0BsGBY7FTgSttzgSvLNsBhwNyyLAbOg+eK0FLgAGB/YKmkoSlizitth44b/lkRETHBxiwqkh6X9FhZfidpq6THxjrO9jXA5mHhhcCKsr4COKIlfqErNwDTJM0ADgXW2t5sewuwFlhQ9r3S9vXl6uTClnNFRESX1Hnz4ytatyUdQXXVMB572d5YzrtR0p4lPhO4v6XdYImNFh9sE29L0mKqqxpmz549ztQjImIsdUZ/bcP2d2j+GRW1+6hxxNuyvcx2v+3+vr6+caYYERFjqTOh5JEtm1OAfkb5D/gYHpQ0o1ylzOD5l30NAnu3tJsFPFDiBw2LX13is9q0j4iILqpzpfIXLcuhwONUfSDjsQpYVNYXAZe3xI9TZT7waLlNtgY4RNL00kF/CLCm7Htc0vwycuy4lnNFRESX1OlTGdd7VSRdRHWVsYekQapRXH8PXCLpBOBXVG+RhGpI8OHAANWQ4uPLZ2+WdCbPv2nyDNtDnf8f4/khxVeUJSIiumi01wmfPspxtn3maCe2fewIu97d7mTAkhHOsxxY3ia+DnjjaDlERMTEGu1K5bdtYi8DTgBeDYxaVCIiYvIZ7XXCXxxal/QK4GSq21IrgS+OdFxERExeo/aplCfaPwl8iOphxbeWhxAjIiK2M1qfyheAI4FlwJtsPzFhWUVERE8abUjxp4A/oJrM8YGWqVoerzNNS0RETD6j9am84KftIyJickvhiIiIxqSoREREY1JUIiKiMSkqERHRmBSViIhoTIpKREQ0JkUlIiIak6ISERGNSVGJiIjGpKhERERjUlQiIqIxKSoREdGYFJWIiGhMikpERDQmRSUiIhrTlaIi6T5Jd0haL2ldie0uaa2ke8rf6SUuSedKGpB0u6S3tpxnUWl/j6RF3fguERHxvG5eqbzL9jzb/WX7VOBK23OBK8s2wGHA3LIsBs6DqggBS4EDgP2BpUOFKCIiumNHuv21EFhR1lcAR7TEL3TlBmCapBnAocBa25ttbwHWAgsmOumIiHhet4qKgR9KukXS4hLby/ZGgPJ3zxKfCdzfcuxgiY0Uj4iILhnxHfUddqDtByTtCayV9PNR2qpNzKPEtz9BVbgWA8yePfuF5hoRETV15UrF9gPl70PAt6n6RB4st7Uofx8qzQeBvVsOnwU8MEq83ects91vu7+vr6/JrxIRES0mvKhIepmkVwytA4cAdwKrgKERXIuAy8v6KuC4MgpsPvBouT22BjhE0vTSQX9IiUVERJd04/bXXsC3JQ19/r/b/oGkm4FLJJ0A/Ao4urRfDRwODABPAscD2N4s6Uzg5tLuDNubJ+5rRETEcBNeVGzfC7y5Tfw3wLvbxA0sGeFcy4HlTecYERHjsyMNKY6IiB6XohIREY3p1pDinvC2T1/Y7RR+793yheO6nUJENChXKhER0ZgUlYiIaEyKSkRENCZFJSIiGpOiEhERjUlRiYiIxqSoREREY1JUIiKiMSkqERHRmBSViIhoTIpKREQ0JkUlIiIak6ISERGNSVGJiIjGpKhERERjUlQiIqIxKSoREdGYFJWIiGhMzxcVSQsk3S1pQNKp3c4nImIy6+miImkq8BXgMGA/4FhJ+3U3q4iIyauniwqwPzBg+17bTwMrgYVdzikiYtLq9aIyE7i/ZXuwxCIiogt26nYCL5LaxLxdI2kxsLhsPiHp7o5m1V17AA93O4m69I+Lup3CjqSnfjsAlrb7V3DS6qnf
T594wb/dH9Zp1OtFZRDYu2V7FvDA8Ea2lwHLJiqpbpK0znZ/t/OIFy6/XW/L71fp9dtfNwNzJe0jaRfgGGBVl3OKiJi0evpKxfYzkk4C1gBTgeW2N3Q5rYiISauniwqA7dXA6m7nsQOZFLf5fk/lt+tt+f0A2dv1a0dERIxLr/epRETEDiRFpQeNNTWNpF0lXVz23yhpzsRnGe1IWi7pIUl3jrBfks4tv93tkt460TnGyCTtLekqSXdJ2iDp5DZtJvVvmKLSY2pOTXMCsMX264BzgM9PbJYxiguABaPsPwyYW5bFwHkTkFPU9wzwKdv7AvOBJW3+/ZvUv2GKSu+pMzXNQmBFWb8UeLekPKW2A7B9DbB5lCYLgQtduQGYJmnGxGQXY7G90fatZf1x4C62n8VjUv+GKSq9p87UNM+1sf0M8Cjw6gnJLl6sTD3UI8pt5bcANw7bNal/wxSV3lNnappa09fEDim/XQ+Q9HLgW8Apth8bvrvNIZPmN0xR6T11pqZ5ro2knYBXMfotl9hx1Jp6KLpH0s5UBeUbti9r02RS/4YpKr2nztQ0q4ChmRqPAn7kPJDUK1YBx5URRPOBR21v7HZSUSl9k+cDd9n+0gjNJvVv2PNP1E82I01NI+kMYJ3tVVT/o/+6pAGqK5RjupdxtJJ0EXAQsIekQWApsDOA7X+hmh3icGAAeBI4vjuZxggOBD4M3CFpfYl9FpgN+Q0hT9RHRESDcvsrIiIak6ISERGNSVGJiIjGpKhERERjUlQiIqIxKSoREdGYFJUIQNJParQ5RdJLO5zHPEmHj9HmI5L+ueHPbfycMTmlqEQAtt9Zo9kpwAsqKuVVBS/EPKoH5yJ6UopKBCDpifL3IElXS7pU0s8lfaNMt/EJ4A+AqyRdVdoeIul6SbdK+maZZBBJ90k6XdK1wNGSXivpB5JukfR/JL2+tDta0p2SbpN0TZl25wzgg5LWS/pgjbz7JH1L0s1lOVDSlJLDtJZ2A5L2ate+8X+YMallmpaI7b0FeAPVJIDXAQfaPlfSJ4F32X5Y0h7AXwPvsf1bSZ8BPklVFAB+Z/tPASRdCZxo+x5JBwBfBQ4GTgcOtf1rSdNsPy3pdKDf9kk1c/0ycI7tayXNBtbY3lfS5cD7ga+Vz7zP9oOS/n14e2DfF/nPK+I5KSoR27vJ9iBAmd9pDnDtsDbzqd68eV15/9kuwPUt+y8ux78ceCfwzZb3pO1a/l4HXCDpEqDdbLd1vAfYr+Xcr5T0ivL5pwNfo5r77eIx2kc0IkUlYntPtaxvpf2/JwLW2j52hHP8tvydAjxie97wBrZPLFcR7wXWS9quTQ1TgHfY/n/bJCddD7xOUh9wBHDWGO3H8dER20ufSkR9jwND/6/+BuBASa8DkPRSSX80/IDyAqdfSjq6tJOkN5f119q+0fbpwMNU7+Bo/Yw6fgg8d6tsqDCVVx18G/gS1TTtvxmtfURTUlQi6lsGXCHpKtubgI8AF0m6narIvH6E4z4EnCDpNmAD1TvMAb4g6Q5JdwLXALcBV1HdnqrVUQ98AuiXdLuknwEntuy7GPhLnr/1NVb7iBctU99HRERjcqUSERGNSUd9xA5K0vHAycPC19le0o18IurI7a+IiGhMbn9FRERjUlQiIqIxKSoREdGYFJWIiGhMikpERDTm/wMOO7fQmnpMBgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f5d3745a978>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "sns.countplot(dtrain.get_label());\n",
    "pyplot.xlabel('interest_level');\n",
    "pyplot.ylabel('Number of occurrences');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The class distribution is noticeably imbalanced."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Get an initial estimate of the number of weak learners (n_estimators)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-validation setup: stratified 5-fold keeps each fold's class proportions close to those of the full dataset\n",
    "kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=3)"
   ]
  },
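  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick aside (a toy sketch, not part of the tuning run): StratifiedKFold splits the data so that every fold preserves the overall class proportions, which matters here because the interest_level classes are imbalanced. For example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration: each fold keeps the same class ratio as the full data\n",
    "y_toy = np.array([0]*8 + [1]*4)\n",
    "skf = StratifiedKFold(n_splits=4, shuffle=True, random_state=3)\n",
    "for _, val_idx in skf.split(np.zeros((len(y_toy), 1)), y_toy):\n",
    "    print(np.bincount(y_toy[val_idx]))  # each fold: 2 of class 0, 1 of class 1"
   ]
  },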
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'base_score': 0.5, 'booster': 'gbtree', 'colsample_bylevel': 0.7, 'colsample_bytree': 0.8, 'gamma': 0, 'learning_rate': 0.1, 'max_delta_step': 0, 'max_depth': 5, 'min_child_weight': 1, 'missing': None, 'n_estimators': 1000, 'objective': 'multi:softprob', 'reg_alpha': 0, 'reg_lambda': 1, 'scale_pos_weight': 1, 'seed': 3, 'silent': 1, 'subsample': 0.3, 'num_class': 3}\n",
      "     test-mlogloss-mean  test-mlogloss-std  train-mlogloss-mean  \\\n",
      "0              1.039879           0.000241             1.039046   \n",
      "1              0.990196           0.000398             0.988831   \n",
      "2              0.948090           0.000464             0.946050   \n",
      "3              0.912131           0.000303             0.909602   \n",
      "4              0.880770           0.000295             0.877685   \n",
      "5              0.853282           0.000830             0.849719   \n",
      "6              0.829310           0.000929             0.825136   \n",
      "7              0.808341           0.000936             0.803658   \n",
      "8              0.790273           0.000830             0.784936   \n",
      "9              0.773685           0.000827             0.767779   \n",
      "10             0.759053           0.000927             0.752584   \n",
      "11             0.746137           0.001010             0.739088   \n",
      "12             0.734827           0.001182             0.727321   \n",
      "13             0.724438           0.001262             0.716524   \n",
      "14             0.715292           0.001385             0.707011   \n",
      "15             0.707090           0.001374             0.698397   \n",
      "16             0.699655           0.001400             0.690538   \n",
      "17             0.693071           0.001553             0.683374   \n",
      "18             0.687009           0.001558             0.676746   \n",
      "19             0.681585           0.001320             0.670823   \n",
      "20             0.677051           0.001437             0.665799   \n",
      "21             0.672457           0.001416             0.660835   \n",
      "22             0.668392           0.001253             0.656266   \n",
      "23             0.664783           0.001248             0.652065   \n",
      "24             0.661161           0.001226             0.648048   \n",
      "25             0.657922           0.001443             0.644344   \n",
      "26             0.654968           0.001503             0.640967   \n",
      "27             0.652184           0.001418             0.637668   \n",
      "28             0.649794           0.001520             0.634833   \n",
      "29             0.647531           0.001597             0.632101   \n",
      "..                  ...                ...                  ...   \n",
      "272            0.590017           0.002041             0.481678   \n",
      "273            0.590023           0.002106             0.481319   \n",
      "274            0.589995           0.002061             0.481034   \n",
      "275            0.589978           0.002096             0.480661   \n",
      "276            0.589970           0.002077             0.480343   \n",
      "277            0.589988           0.002081             0.479972   \n",
      "278            0.590021           0.002054             0.479667   \n",
      "279            0.589960           0.002046             0.479359   \n",
      "280            0.590020           0.002118             0.478990   \n",
      "281            0.589984           0.002109             0.478682   \n",
      "282            0.589919           0.002139             0.478343   \n",
      "283            0.589980           0.002125             0.478117   \n",
      "284            0.589971           0.002098             0.477775   \n",
      "285            0.589963           0.002099             0.477511   \n",
      "286            0.590016           0.002130             0.477137   \n",
      "287            0.589982           0.002174             0.476812   \n",
      "288            0.589977           0.002146             0.476533   \n",
      "289            0.590021           0.002134             0.476217   \n",
      "290            0.590060           0.002170             0.475865   \n",
      "291            0.589965           0.002106             0.475511   \n",
      "292            0.589971           0.002102             0.475138   \n",
      "293            0.589984           0.002050             0.474846   \n",
      "294            0.589942           0.002087             0.474516   \n",
      "295            0.589919           0.002174             0.474261   \n",
      "296            0.589919           0.002144             0.473940   \n",
      "297            0.589907           0.002140             0.473618   \n",
      "298            0.589899           0.002123             0.473338   \n",
      "299            0.589888           0.002152             0.472983   \n",
      "300            0.589861           0.002245             0.472661   \n",
      "301            0.589839           0.002211             0.472333   \n",
      "\n",
      "     train-mlogloss-std  \n",
      "0              0.000419  \n",
      "1              0.000522  \n",
      "2              0.000165  \n",
      "3              0.000681  \n",
      "4              0.000717  \n",
      "5              0.000231  \n",
      "6              0.000095  \n",
      "7              0.000260  \n",
      "8              0.000806  \n",
      "9              0.000830  \n",
      "10             0.001036  \n",
      "11             0.001137  \n",
      "12             0.001210  \n",
      "13             0.001043  \n",
      "14             0.001064  \n",
      "15             0.001228  \n",
      "16             0.001128  \n",
      "17             0.001165  \n",
      "18             0.001231  \n",
      "19             0.001395  \n",
      "20             0.001258  \n",
      "21             0.001302  \n",
      "22             0.001615  \n",
      "23             0.001785  \n",
      "24             0.001823  \n",
      "25             0.001512  \n",
      "26             0.001340  \n",
      "27             0.001508  \n",
      "28             0.001503  \n",
      "29             0.001247  \n",
      "..                  ...  \n",
      "272            0.001040  \n",
      "273            0.000967  \n",
      "274            0.000963  \n",
      "275            0.000960  \n",
      "276            0.000949  \n",
      "277            0.000883  \n",
      "278            0.000885  \n",
      "279            0.000874  \n",
      "280            0.000832  \n",
      "281            0.000807  \n",
      "282            0.000848  \n",
      "283            0.000823  \n",
      "284            0.000800  \n",
      "285            0.000777  \n",
      "286            0.000769  \n",
      "287            0.000766  \n",
      "288            0.000785  \n",
      "289            0.000761  \n",
      "290            0.000726  \n",
      "291            0.000747  \n",
      "292            0.000810  \n",
      "293            0.000812  \n",
      "294            0.000796  \n",
      "295            0.000792  \n",
      "296            0.000744  \n",
      "297            0.000789  \n",
      "298            0.000812  \n",
      "299            0.000812  \n",
      "300            0.000760  \n",
      "301            0.000791  \n",
      "\n",
      "[302 rows x 4 columns]\n",
      "302\n"
     ]
    }
   ],
   "source": [
    "# First pass with coarse parameters to get an overall picture\n",
    "xgb_1 = XGBClassifier(\n",
    "        learning_rate =0.1,\n",
    "        n_estimators=1000,  # a large value is fine; cv with early stopping returns a suitable n_estimators\n",
    "        max_depth=5,\n",
    "        min_child_weight=1,\n",
    "        gamma=0,\n",
    "        subsample=0.3,\n",
    "        colsample_bytree=0.8,\n",
    "        colsample_bylevel=0.7,\n",
    "        objective= 'multi:softprob',\n",
    "        nthread=-1,\n",
    "        seed=3)\n",
    "xgb_param = xgb_1.get_xgb_params()\n",
    "xgb_param['num_class'] = 3\n",
    "print(xgb_param)\n",
    "# The params could be passed to xgb.cv directly without XGBClassifier, but then cv only used 1 CPU core,\n",
    "# while with XGBClassifier and nthread=-1 it uses all cores -- odd, but it saves time\n",
    "cvresult = xgb.cv(xgb_param, \n",
    "            dtrain, \n",
    "            num_boost_round=1000, \n",
    "            folds =kfold,\n",
    "            metrics='mlogloss', \n",
    "            early_stopping_rounds=100)\n",
    "# Save the CV results\n",
    "cvresult.to_csv('1_nestimators.csv', index_label = 'n_estimators')\n",
    "n_estimators = cvresult.shape[0]\n",
    "print(cvresult)\n",
    "print(n_estimators)"
   ]
  },
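  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When early_stopping_rounds is set, xgb.cv truncates the returned DataFrame at the best round, so cvresult.shape[0] is the chosen n_estimators (302 here). A quick sanity plot of the CV learning curve (assuming cvresult from the cell above):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot train/test mlogloss per boosting round from the CV run above\n",
    "pyplot.plot(cvresult['train-mlogloss-mean'], label='train')\n",
    "pyplot.plot(cvresult['test-mlogloss-mean'], label='test')\n",
    "pyplot.xlabel('boosting round')\n",
    "pyplot.ylabel('mlogloss')\n",
    "pyplot.legend();"
   ]
  },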
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. Tune the maximum tree depth max_depth and min_child_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'max_depth': range(3, 10, 2), 'min_child_weight': range(1, 6, 2)}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Coarse grid, following the parameter ranges from the course example\n",
    "max_depth = range(3,10,2)\n",
    "min_child_weight = range(1,6,2)\n",
    "param_2 = dict(max_depth=max_depth, min_child_weight=min_child_weight)\n",
    "param_2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ai/tool/bin/anaconda3/lib/python3.6/site-packages/sklearn/model_selection/_search.py:761: DeprecationWarning: The grid_scores_ attribute was deprecated in version 0.18 in favor of the more elaborate cv_results_ attribute. The grid_scores_ attribute will not be available from 0.20\n",
      "  DeprecationWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([mean: -0.59698, std: 0.00305, params: {'max_depth': 3, 'min_child_weight': 1},\n",
       "  mean: -0.59674, std: 0.00325, params: {'max_depth': 3, 'min_child_weight': 3},\n",
       "  mean: -0.59669, std: 0.00311, params: {'max_depth': 3, 'min_child_weight': 5},\n",
       "  mean: -0.58828, std: 0.00391, params: {'max_depth': 5, 'min_child_weight': 1},\n",
       "  mean: -0.58824, std: 0.00394, params: {'max_depth': 5, 'min_child_weight': 3},\n",
       "  mean: -0.58800, std: 0.00328, params: {'max_depth': 5, 'min_child_weight': 5},\n",
       "  mean: -0.59361, std: 0.00521, params: {'max_depth': 7, 'min_child_weight': 1},\n",
       "  mean: -0.59163, std: 0.00504, params: {'max_depth': 7, 'min_child_weight': 3},\n",
       "  mean: -0.59086, std: 0.00438, params: {'max_depth': 7, 'min_child_weight': 5},\n",
       "  mean: -0.61489, std: 0.00578, params: {'max_depth': 9, 'min_child_weight': 1},\n",
       "  mean: -0.60627, std: 0.00461, params: {'max_depth': 9, 'min_child_weight': 3},\n",
       "  mean: -0.60226, std: 0.00384, params: {'max_depth': 9, 'min_child_weight': 5}],\n",
       " {'max_depth': 5, 'min_child_weight': 5},\n",
       " -0.5880016841361746)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Set n_estimators to the value found above\n",
    "xgb_2 = XGBClassifier(\n",
    "        learning_rate =0.1,\n",
    "        n_estimators=302,  # best n_estimators from round 1\n",
    "        max_depth=5,\n",
    "        min_child_weight=1,\n",
    "        gamma=0,\n",
    "        subsample=0.3,\n",
    "        colsample_bytree=0.8,\n",
    "        colsample_bylevel = 0.7,\n",
    "        objective= 'multi:softprob',\n",
    "        nthread=-1,\n",
    "        seed=3)\n",
    "gsearch_2 = GridSearchCV(xgb_2, param_grid = param_2, scoring='neg_log_loss',n_jobs=-1, cv=kfold, return_train_score=True)\n",
    "gsearch_2.fit(X_train , y_train)\n",
    "\n",
    "gsearch_2.grid_scores_, gsearch_2.best_params_, gsearch_2.best_score_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'mean_fit_time': array([129.20172372, 126.60045013, 127.27998028, 198.20038252,\n",
       "        201.81758571, 196.96690502, 276.29535966, 272.54735055,\n",
       "        273.02435198, 347.29126482, 338.62546539, 268.21926589]),\n",
       " 'mean_score_time': array([0.53267541, 0.54582825, 0.52477474, 0.87576437, 0.85085826,\n",
       "        0.85391679, 3.13586025, 2.99922571, 2.52431059, 6.70815415,\n",
       "        4.82494931, 1.18751459]),\n",
       " 'mean_test_score': array([-0.59698263, -0.59673646, -0.59668512, -0.58827646, -0.58824391,\n",
       "        -0.58800168, -0.59361108, -0.5916261 , -0.59085665, -0.61488994,\n",
       "        -0.60626731, -0.6022617 ]),\n",
       " 'mean_train_score': array([-0.5674826 , -0.56847758, -0.5689916 , -0.48925459, -0.49617301,\n",
       "        -0.50096409, -0.36657095, -0.39134523, -0.40797155, -0.2306495 ,\n",
       "        -0.28239318, -0.31361588]),\n",
       " 'param_max_depth': masked_array(data=[3, 3, 3, 5, 5, 5, 7, 7, 7, 9, 9, 9],\n",
       "              mask=[False, False, False, False, False, False, False, False,\n",
       "                    False, False, False, False],\n",
       "        fill_value='?',\n",
       "             dtype=object),\n",
       " 'param_min_child_weight': masked_array(data=[1, 3, 5, 1, 3, 5, 1, 3, 5, 1, 3, 5],\n",
       "              mask=[False, False, False, False, False, False, False, False,\n",
       "                    False, False, False, False],\n",
       "        fill_value='?',\n",
       "             dtype=object),\n",
       " 'params': [{'max_depth': 3, 'min_child_weight': 1},\n",
       "  {'max_depth': 3, 'min_child_weight': 3},\n",
       "  {'max_depth': 3, 'min_child_weight': 5},\n",
       "  {'max_depth': 5, 'min_child_weight': 1},\n",
       "  {'max_depth': 5, 'min_child_weight': 3},\n",
       "  {'max_depth': 5, 'min_child_weight': 5},\n",
       "  {'max_depth': 7, 'min_child_weight': 1},\n",
       "  {'max_depth': 7, 'min_child_weight': 3},\n",
       "  {'max_depth': 7, 'min_child_weight': 5},\n",
       "  {'max_depth': 9, 'min_child_weight': 1},\n",
       "  {'max_depth': 9, 'min_child_weight': 3},\n",
       "  {'max_depth': 9, 'min_child_weight': 5}],\n",
       " 'rank_test_score': array([ 9,  8,  7,  3,  2,  1,  6,  5,  4, 12, 11, 10], dtype=int32),\n",
       " 'split0_test_score': array([-0.59204832, -0.5910662 , -0.59173358, -0.58134611, -0.58212964,\n",
       "        -0.58264213, -0.58562341, -0.58218238, -0.58366399, -0.60369655,\n",
       "        -0.60046295, -0.59607697]),\n",
       " 'split0_train_score': array([-0.56896491, -0.57017341, -0.57027648, -0.49126119, -0.49812216,\n",
       "        -0.50258694, -0.36765255, -0.393789  , -0.40926838, -0.23092178,\n",
       "        -0.28315377, -0.31370155]),\n",
       " 'split1_test_score': array([-0.59564702, -0.59541672, -0.59491354, -0.58697021, -0.5852546 ,\n",
       "        -0.58608315, -0.5930449 , -0.59139829, -0.58951733, -0.61626541,\n",
       "        -0.60122823, -0.60184301]),\n",
       " 'split1_train_score': array([-0.56843727, -0.56947654, -0.56983048, -0.48901562, -0.49602313,\n",
       "        -0.5012681 , -0.3669019 , -0.39124287, -0.40760719, -0.23070625,\n",
       "        -0.28452229, -0.31320718]),\n",
       " 'split2_test_score': array([-0.59709061, -0.59786313, -0.59700977, -0.58952426, -0.58980654,\n",
       "        -0.58951103, -0.59268308, -0.59368529, -0.59118965, -0.61592373,\n",
       "        -0.60760948, -0.6025273 ]),\n",
       " 'split2_train_score': array([-0.56737275, -0.5682099 , -0.56860248, -0.48867472, -0.49586225,\n",
       "        -0.50060265, -0.3657479 , -0.39035741, -0.40799165, -0.23062096,\n",
       "        -0.28257993, -0.31545835]),\n",
       " 'split3_test_score': array([-0.60091961, -0.60017594, -0.59999024, -0.59114081, -0.59165383,\n",
       "        -0.58979344, -0.594773  , -0.59389748, -0.59287827, -0.61927957,\n",
       "        -0.61135069, -0.60267999]),\n",
       " 'split3_train_score': array([-0.56655181, -0.56752701, -0.56815855, -0.49026091, -0.49626713,\n",
       "        -0.50045636, -0.36540818, -0.39009982, -0.40695067, -0.23082524,\n",
       "        -0.28085445, -0.31175937]),\n",
       " 'split4_test_score': array([-0.59920829, -0.59916103, -0.59977943, -0.59240218, -0.59237621,\n",
       "        -0.59197987, -0.60193356, -0.59696866, -0.59703586, -0.61928578,\n",
       "        -0.61068654, -0.60818306]),\n",
       " 'split4_train_score': array([-0.56608625, -0.56700106, -0.56809   , -0.48706054, -0.4945904 ,\n",
       "        -0.49990641, -0.36714422, -0.39123707, -0.40803987, -0.23017326,\n",
       "        -0.28085544, -0.31395296]),\n",
       " 'std_fit_time': array([ 2.01544264,  2.43510594,  2.85158673,  2.38587125,  5.02303748,\n",
       "         0.76121915,  5.57746069,  5.97417331,  6.74741261,  2.37356911,\n",
       "         3.7838462 , 22.87506547]),\n",
       " 'std_score_time': array([0.02353768, 0.07196024, 0.03298554, 0.05454467, 0.03892465,\n",
       "        0.01697357, 0.60010335, 0.61598829, 0.37159121, 0.39377809,\n",
       "        1.12127709, 0.16035659]),\n",
       " 'std_test_score': array([0.00305279, 0.00325186, 0.00310677, 0.003912  , 0.00393612,\n",
       "        0.00327776, 0.00520741, 0.00504336, 0.00437956, 0.00577666,\n",
       "        0.00460989, 0.00384   ]),\n",
       " 'std_train_score': array([0.00108967, 0.00118567, 0.00089587, 0.00143152, 0.00113417,\n",
       "        0.00092013, 0.00085293, 0.00130545, 0.00075654, 0.00025918,\n",
       "        0.00140562, 0.0011941 ])}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the full CV results\n",
    "gsearch_2.cv_results_"
   ]
  },
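  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "grid_scores_ is deprecated (see the warning above) and is removed in scikit-learn 0.20; the same summary can be read from cv_results_, for example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compact replacement for the deprecated grid_scores_: one row per parameter combination\n",
    "res = pd.DataFrame(gsearch_2.cv_results_)\n",
    "res[['params', 'mean_test_score', 'std_test_score', 'rank_test_score']].sort_values('rank_test_score')"
   ]
  },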
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'max_depth': [4, 5, 6], 'min_child_weight': [4, 5, 6]}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Narrow the step size to fine-tune max_depth and min_child_weight around the best coarse values\n",
    "max_depth = [4,5,6]\n",
    "min_child_weight = [4,5,6]\n",
    "param_3 = dict(max_depth=max_depth, min_child_weight=min_child_weight)\n",
    "param_3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ai/tool/bin/anaconda3/lib/python3.6/site-packages/sklearn/model_selection/_search.py:761: DeprecationWarning: The grid_scores_ attribute was deprecated in version 0.18 in favor of the more elaborate cv_results_ attribute. The grid_scores_ attribute will not be available from 0.20\n",
      "  DeprecationWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([mean: -0.59063, std: 0.00394, params: {'max_depth': 4, 'min_child_weight': 4},\n",
       "  mean: -0.59084, std: 0.00386, params: {'max_depth': 4, 'min_child_weight': 5},\n",
       "  mean: -0.59061, std: 0.00409, params: {'max_depth': 4, 'min_child_weight': 6},\n",
       "  mean: -0.58792, std: 0.00515, params: {'max_depth': 5, 'min_child_weight': 4},\n",
       "  mean: -0.58800, std: 0.00328, params: {'max_depth': 5, 'min_child_weight': 5},\n",
       "  mean: -0.58892, std: 0.00375, params: {'max_depth': 5, 'min_child_weight': 6},\n",
       "  mean: -0.59052, std: 0.00366, params: {'max_depth': 6, 'min_child_weight': 4},\n",
       "  mean: -0.58981, std: 0.00422, params: {'max_depth': 6, 'min_child_weight': 5},\n",
       "  mean: -0.58923, std: 0.00354, params: {'max_depth': 6, 'min_child_weight': 6}],\n",
       " {'max_depth': 5, 'min_child_weight': 4},\n",
       " -0.5879152170783184)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#再次运行GridSearchCV进行验证\n",
    "xgb_3 = XGBClassifier(\n",
    "        learning_rate =0.1,\n",
    "        n_estimators=302,  #第一轮参数调整得到的n_estimators最优值\n",
    "        max_depth=5,\n",
    "        min_child_weight=1,\n",
    "        gamma=0,\n",
    "        subsample=0.3,\n",
    "        colsample_bytree=0.8,\n",
    "        colsample_bylevel = 0.7,\n",
    "        objective= 'multi:softprob',\n",
    "        nthread=-1,\n",
    "        seed=3)\n",
    "\n",
    "gsearch3 = GridSearchCV(xgb_3, param_grid = param_3, scoring='neg_log_loss',n_jobs=-1, cv=kfold, return_train_score=True)\n",
    "gsearch3.fit(X_train , y_train)\n",
    "\n",
    "gsearch3.grid_scores_, gsearch3.best_params_, gsearch3.best_score_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这次得到max_depth和min_child_weight最优解分别为5和4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3.重新调整弱学习器数目"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "     test-mlogloss-mean  test-mlogloss-std  train-mlogloss-mean  \\\n",
      "0              1.039928           0.000188             1.039138   \n",
      "1              0.990261           0.000337             0.988982   \n",
      "2              0.948213           0.000529             0.946266   \n",
      "3              0.912220           0.000235             0.909804   \n",
      "4              0.880873           0.000258             0.877942   \n",
      "5              0.853370           0.000867             0.849996   \n",
      "6              0.829386           0.000970             0.825378   \n",
      "7              0.808416           0.001007             0.803911   \n",
      "8              0.790323           0.000821             0.785246   \n",
      "9              0.773721           0.000803             0.768137   \n",
      "10             0.759039           0.000956             0.752976   \n",
      "11             0.746175           0.001010             0.739499   \n",
      "12             0.734878           0.001220             0.727799   \n",
      "13             0.724501           0.001290             0.717020   \n",
      "14             0.715308           0.001424             0.707508   \n",
      "15             0.707134           0.001409             0.698928   \n",
      "16             0.699713           0.001425             0.691063   \n",
      "17             0.693042           0.001538             0.683880   \n",
      "18             0.687028           0.001524             0.677271   \n",
      "19             0.681592           0.001262             0.671329   \n",
      "20             0.676916           0.001261             0.666280   \n",
      "21             0.672398           0.001296             0.661429   \n",
      "22             0.668282           0.001196             0.656917   \n",
      "23             0.664736           0.001285             0.652821   \n",
      "24             0.661196           0.001317             0.648902   \n",
      "25             0.658060           0.001520             0.645276   \n",
      "26             0.655133           0.001583             0.641976   \n",
      "27             0.652452           0.001603             0.638820   \n",
      "28             0.650056           0.001746             0.636010   \n",
      "29             0.647793           0.001790             0.633300   \n",
      "..                  ...                ...                  ...   \n",
      "272            0.590293           0.002047             0.492416   \n",
      "273            0.590297           0.002076             0.492075   \n",
      "274            0.590311           0.002081             0.491869   \n",
      "275            0.590348           0.002134             0.491567   \n",
      "276            0.590319           0.002164             0.491320   \n",
      "277            0.590312           0.002151             0.490959   \n",
      "278            0.590281           0.002184             0.490653   \n",
      "279            0.590249           0.002115             0.490373   \n",
      "280            0.590242           0.002125             0.490081   \n",
      "281            0.590222           0.002147             0.489822   \n",
      "282            0.590193           0.002179             0.489557   \n",
      "283            0.590285           0.002168             0.489319   \n",
      "284            0.590286           0.002158             0.489041   \n",
      "285            0.590311           0.002146             0.488771   \n",
      "286            0.590381           0.002174             0.488457   \n",
      "287            0.590332           0.002151             0.488194   \n",
      "288            0.590407           0.002100             0.487917   \n",
      "289            0.590455           0.002059             0.487614   \n",
      "290            0.590431           0.002028             0.487304   \n",
      "291            0.590461           0.001966             0.487031   \n",
      "292            0.590476           0.001959             0.486719   \n",
      "293            0.590549           0.001974             0.486442   \n",
      "294            0.590482           0.002037             0.486114   \n",
      "295            0.590446           0.002060             0.485869   \n",
      "296            0.590415           0.002023             0.485565   \n",
      "297            0.590414           0.002037             0.485295   \n",
      "298            0.590356           0.002004             0.485012   \n",
      "299            0.590318           0.002067             0.484676   \n",
      "300            0.590287           0.002095             0.484399   \n",
      "301            0.590268           0.002117             0.484118   \n",
      "\n",
      "     train-mlogloss-std  \n",
      "0              0.000404  \n",
      "1              0.000479  \n",
      "2              0.000205  \n",
      "3              0.000693  \n",
      "4              0.000718  \n",
      "5              0.000187  \n",
      "6              0.000190  \n",
      "7              0.000328  \n",
      "8              0.000878  \n",
      "9              0.000895  \n",
      "10             0.001101  \n",
      "11             0.001211  \n",
      "12             0.001187  \n",
      "13             0.001074  \n",
      "14             0.001130  \n",
      "15             0.001308  \n",
      "16             0.001178  \n",
      "17             0.001229  \n",
      "18             0.001286  \n",
      "19             0.001433  \n",
      "20             0.001411  \n",
      "21             0.001483  \n",
      "22             0.001721  \n",
      "23             0.001842  \n",
      "24             0.001977  \n",
      "25             0.001802  \n",
      "26             0.001664  \n",
      "27             0.001742  \n",
      "28             0.001768  \n",
      "29             0.001524  \n",
      "..                  ...  \n",
      "272            0.001042  \n",
      "273            0.000993  \n",
      "274            0.000944  \n",
      "275            0.000923  \n",
      "276            0.000893  \n",
      "277            0.000853  \n",
      "278            0.000836  \n",
      "279            0.000863  \n",
      "280            0.000829  \n",
      "281            0.000804  \n",
      "282            0.000831  \n",
      "283            0.000850  \n",
      "284            0.000863  \n",
      "285            0.000880  \n",
      "286            0.000930  \n",
      "287            0.000981  \n",
      "288            0.001030  \n",
      "289            0.000986  \n",
      "290            0.000976  \n",
      "291            0.000997  \n",
      "292            0.001032  \n",
      "293            0.001012  \n",
      "294            0.001058  \n",
      "295            0.001054  \n",
      "296            0.001011  \n",
      "297            0.000997  \n",
      "298            0.000926  \n",
      "299            0.000946  \n",
      "300            0.000928  \n",
      "301            0.000963  \n",
      "\n",
      "[302 rows x 4 columns]\n",
      "302\n"
     ]
    }
   ],
   "source": [
    "xgb_4 = XGBClassifier(\n",
    "        learning_rate =0.1,\n",
    "        n_estimators=302,\n",
    "        max_depth=5,\n",
    "        min_child_weight=4,\n",
    "        gamma=0,\n",
    "        subsample=0.3,\n",
    "        colsample_bytree=0.8,\n",
    "        colsample_bylevel=0.7,\n",
    "        objective= 'multi:softprob',\n",
    "        nthread=-1,\n",
    "        seed=3)\n",
    "xgb_param = xgb_4.get_xgb_params()\n",
    "xgb_param['num_class'] = 3\n",
    "cvresult = xgb.cv(xgb_param, \n",
    "            dtrain, \n",
    "            num_boost_round=302, \n",
    "            folds =kfold,\n",
    "            metrics='mlogloss', \n",
    "            early_stopping_rounds=100)\n",
    "#将得到的数据保存下来\n",
    "cvresult.to_csv('2_nestimators.csv', index_label = 'n_estimators')\n",
    "n_estimators = cvresult.shape[0]\n",
    "print(cvresult)\n",
    "print(n_estimators)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "重新调整出来的n_estimators并没有减少，比较奇怪\n",
    "同时将上面num_boost_round=302改为1000试过，结果为334,比原来还大，不太明白，待老师讲解"
   ]
  },
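My guess at the mechanics, as a pure-Python sketch (this is not xgboost's actual implementation, just the early-stopping rule it documents): `early_stopping_rounds` only stops once the validation metric has failed to improve for that many rounds, and the result is otherwise capped by `num_boost_round`. So if the metric is still improving slowly near the cap, a larger budget can legitimately return a larger best round:

```python
# Hypothetical helper illustrating the early-stopping rule: return the
# round with the best (lowest) metric, stopping after `patience` rounds
# without improvement, within a budget of `num_boost_round` rounds.
def best_round(metric_curve, num_boost_round, patience=100):
    best, best_i, since = float("inf"), 0, 0
    for i, m in enumerate(metric_curve[:num_boost_round]):
        if m < best:
            best, best_i, since = m, i, 0
        else:
            since += 1
            if since >= patience:
                break
    return best_i

# A synthetic mlogloss curve that keeps improving slightly past round 302.
curve = [1.0 / (1 + 0.01 * i) for i in range(400)]
print(best_round(curve, 302))   # 301: truncated at the budget, not converged
print(best_round(curve, 1000))  # 399: the larger budget finds a later optimum
```

Under this reading, 302 was simply where the first run's budget ended, and 334 with `num_boost_round=1000` is the better estimate.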
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "4.行列重采样参数调整"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'colsample_bytree': [0.6, 0.7, 0.8, 0.9],\n",
       " 'subsample': [0.3, 0.4, 0.5, 0.6, 0.7, 0.8]}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#此处参数用命课程里的参考值\n",
    "subsample = [i/10.0 for i in range(3,9)]\n",
    "colsample_bytree = [i/10.0 for i in range(6,10)]\n",
    "param_5 = dict(subsample=subsample, colsample_bytree=colsample_bytree)\n",
    "param_5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ai/tool/bin/anaconda3/lib/python3.6/site-packages/sklearn/model_selection/_search.py:761: DeprecationWarning: The grid_scores_ attribute was deprecated in version 0.18 in favor of the more elaborate cv_results_ attribute. The grid_scores_ attribute will not be available from 0.20\n",
      "  DeprecationWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([mean: -0.58718, std: 0.00425, params: {'colsample_bytree': 0.6, 'subsample': 0.3},\n",
       "  mean: -0.58551, std: 0.00419, params: {'colsample_bytree': 0.6, 'subsample': 0.4},\n",
       "  mean: -0.58416, std: 0.00347, params: {'colsample_bytree': 0.6, 'subsample': 0.5},\n",
       "  mean: -0.58348, std: 0.00417, params: {'colsample_bytree': 0.6, 'subsample': 0.6},\n",
       "  mean: -0.58342, std: 0.00358, params: {'colsample_bytree': 0.6, 'subsample': 0.7},\n",
       "  mean: -0.58192, std: 0.00374, params: {'colsample_bytree': 0.6, 'subsample': 0.8},\n",
       "  mean: -0.58772, std: 0.00326, params: {'colsample_bytree': 0.7, 'subsample': 0.3},\n",
       "  mean: -0.58470, std: 0.00386, params: {'colsample_bytree': 0.7, 'subsample': 0.4},\n",
       "  mean: -0.58401, std: 0.00367, params: {'colsample_bytree': 0.7, 'subsample': 0.5},\n",
       "  mean: -0.58313, std: 0.00258, params: {'colsample_bytree': 0.7, 'subsample': 0.6},\n",
       "  mean: -0.58283, std: 0.00407, params: {'colsample_bytree': 0.7, 'subsample': 0.7},\n",
       "  mean: -0.58213, std: 0.00423, params: {'colsample_bytree': 0.7, 'subsample': 0.8},\n",
       "  mean: -0.58792, std: 0.00515, params: {'colsample_bytree': 0.8, 'subsample': 0.3},\n",
       "  mean: -0.58574, std: 0.00434, params: {'colsample_bytree': 0.8, 'subsample': 0.4},\n",
       "  mean: -0.58514, std: 0.00397, params: {'colsample_bytree': 0.8, 'subsample': 0.5},\n",
       "  mean: -0.58321, std: 0.00318, params: {'colsample_bytree': 0.8, 'subsample': 0.6},\n",
       "  mean: -0.58279, std: 0.00371, params: {'colsample_bytree': 0.8, 'subsample': 0.7},\n",
       "  mean: -0.58286, std: 0.00348, params: {'colsample_bytree': 0.8, 'subsample': 0.8},\n",
       "  mean: -0.58749, std: 0.00355, params: {'colsample_bytree': 0.9, 'subsample': 0.3},\n",
       "  mean: -0.58528, std: 0.00316, params: {'colsample_bytree': 0.9, 'subsample': 0.4},\n",
       "  mean: -0.58304, std: 0.00403, params: {'colsample_bytree': 0.9, 'subsample': 0.5},\n",
       "  mean: -0.58368, std: 0.00383, params: {'colsample_bytree': 0.9, 'subsample': 0.6},\n",
       "  mean: -0.58297, std: 0.00388, params: {'colsample_bytree': 0.9, 'subsample': 0.7},\n",
       "  mean: -0.58169, std: 0.00321, params: {'colsample_bytree': 0.9, 'subsample': 0.8}],\n",
       " {'colsample_bytree': 0.9, 'subsample': 0.8},\n",
       " -0.5816929480353772)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#对colsample_bytree和subsample进行GridSearchCV交叉验证\n",
    "xgb_5 = XGBClassifier(\n",
    "        learning_rate =0.1,\n",
    "        n_estimators=302,  #第二轮参数调整得到的n_estimators最优值\n",
    "        max_depth=5,\n",
    "        min_child_weight=4,\n",
    "        gamma=0,\n",
    "        subsample=0.3,\n",
    "        colsample_bytree=0.8,\n",
    "        colsample_bylevel=0.7,\n",
    "        objective= 'multi:softprob',\n",
    "        nthread=-1,\n",
    "        seed=3)\n",
    "gsearch5 = GridSearchCV(xgb_5, param_grid = param_5, scoring='neg_log_loss',n_jobs=-1, cv=kfold, return_train_score=True)\n",
    "gsearch5.fit(X_train , y_train)\n",
    "\n",
    "gsearch5.grid_scores_, gsearch5.best_params_,gsearch5.best_score_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "得到colsample_bytree最优解为0.9 ， subsample最优解为0.8\n",
    "两者都在上次测试时的边界处，所以要进行再次调试\n",
    "这次将步进设为0.05,边界最高到1"
   ]
  },
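The refinement step described above can be written as a tiny helper (hypothetical, just to make the procedure explicit): center a new grid on the previous best value with a smaller step, clipped to the valid (0, 1] range:

```python
# Hypothetical grid-refinement helper: n points on each side of the
# previous best, with the given step, clipped to [lo, hi].
def refine_grid(best, step, lo=0.05, hi=1.0, n=2):
    vals = [round(best + k * step, 2) for k in range(-n, n + 1)]
    return [v for v in vals if lo <= v <= hi]

print(refine_grid(0.9, 0.05))  # [0.8, 0.85, 0.9, 0.95, 1.0]
print(refine_grid(0.8, 0.05))  # [0.7, 0.75, 0.8, 0.85, 0.9]
```

This reproduces the colsample_bytree grid used in the next cell; subsample gets a slightly wider six-point grid built the same way.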
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'colsample_bytree': [0.85, 0.9, 0.95, 1.0],\n",
       " 'subsample': [0.75, 0.8, 0.85, 0.9, 0.95, 1.0]}"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "subsample = [i/20.0 for i in range(15,21,1)]\n",
    "colsample_bytree = [i/20.0 for i in range(17,21,1)]\n",
    "param_6 = dict(subsample=subsample, colsample_bytree=colsample_bytree)\n",
    "param_6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ai/tool/bin/anaconda3/lib/python3.6/site-packages/sklearn/model_selection/_search.py:761: DeprecationWarning: The grid_scores_ attribute was deprecated in version 0.18 in favor of the more elaborate cv_results_ attribute. The grid_scores_ attribute will not be available from 0.20\n",
      "  DeprecationWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([mean: -0.58227, std: 0.00447, params: {'colsample_bytree': 0.85, 'subsample': 0.75},\n",
       "  mean: -0.58241, std: 0.00441, params: {'colsample_bytree': 0.85, 'subsample': 0.8},\n",
       "  mean: -0.58216, std: 0.00362, params: {'colsample_bytree': 0.85, 'subsample': 0.85},\n",
       "  mean: -0.58219, std: 0.00350, params: {'colsample_bytree': 0.85, 'subsample': 0.9},\n",
       "  mean: -0.58166, std: 0.00417, params: {'colsample_bytree': 0.85, 'subsample': 0.95},\n",
       "  mean: -0.58359, std: 0.00375, params: {'colsample_bytree': 0.85, 'subsample': 1.0},\n",
       "  mean: -0.58175, std: 0.00364, params: {'colsample_bytree': 0.9, 'subsample': 0.75},\n",
       "  mean: -0.58169, std: 0.00321, params: {'colsample_bytree': 0.9, 'subsample': 0.8},\n",
       "  mean: -0.58206, std: 0.00392, params: {'colsample_bytree': 0.9, 'subsample': 0.85},\n",
       "  mean: -0.58205, std: 0.00394, params: {'colsample_bytree': 0.9, 'subsample': 0.9},\n",
       "  mean: -0.58235, std: 0.00351, params: {'colsample_bytree': 0.9, 'subsample': 0.95},\n",
       "  mean: -0.58285, std: 0.00379, params: {'colsample_bytree': 0.9, 'subsample': 1.0},\n",
       "  mean: -0.58228, std: 0.00340, params: {'colsample_bytree': 0.95, 'subsample': 0.75},\n",
       "  mean: -0.58247, std: 0.00451, params: {'colsample_bytree': 0.95, 'subsample': 0.8},\n",
       "  mean: -0.58198, std: 0.00395, params: {'colsample_bytree': 0.95, 'subsample': 0.85},\n",
       "  mean: -0.58248, std: 0.00356, params: {'colsample_bytree': 0.95, 'subsample': 0.9},\n",
       "  mean: -0.58242, std: 0.00356, params: {'colsample_bytree': 0.95, 'subsample': 0.95},\n",
       "  mean: -0.58323, std: 0.00354, params: {'colsample_bytree': 0.95, 'subsample': 1.0},\n",
       "  mean: -0.58249, std: 0.00373, params: {'colsample_bytree': 1.0, 'subsample': 0.75},\n",
       "  mean: -0.58284, std: 0.00408, params: {'colsample_bytree': 1.0, 'subsample': 0.8},\n",
       "  mean: -0.58223, std: 0.00394, params: {'colsample_bytree': 1.0, 'subsample': 0.85},\n",
       "  mean: -0.58171, std: 0.00285, params: {'colsample_bytree': 1.0, 'subsample': 0.9},\n",
       "  mean: -0.58248, std: 0.00374, params: {'colsample_bytree': 1.0, 'subsample': 0.95},\n",
       "  mean: -0.58313, std: 0.00335, params: {'colsample_bytree': 1.0, 'subsample': 1.0}],\n",
       " {'colsample_bytree': 0.85, 'subsample': 0.95},\n",
       " -0.5816579101593663)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "xgb_6 = XGBClassifier(\n",
    "        learning_rate =0.1,\n",
    "        n_estimators=302,  \n",
    "        max_depth=5,\n",
    "        min_child_weight=4,\n",
    "        gamma=0,\n",
    "        subsample=0.9,\n",
    "        colsample_bytree=0.8,\n",
    "        colsample_bylevel=0.7,\n",
    "        objective= 'multi:softprob',\n",
    "        nthread=-1,\n",
    "        seed=3)\n",
    "gsearch6 = GridSearchCV(xgb_6, param_grid = param_6, scoring='neg_log_loss',n_jobs=-1, cv=kfold, return_train_score=True)\n",
    "gsearch6.fit(X_train , y_train)\n",
    "\n",
    "gsearch6.grid_scores_, gsearch6.best_params_,gsearch6.best_score_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "得到colsample_bytree最优解为0.85,subsample最优解为0.95"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "5.对正则参数进行调优"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'reg_alpha': [1.5, 2, 2.5, 3], 'reg_lambda': [0.1, 0.5, 1, 2]}"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#参考课程中参数并将范围调宽一点\n",
    "reg_alpha = [ 1.5, 2, 2.5, 3] \n",
    "reg_lambda = [0.1, 0.5, 1, 2]     \n",
    "\n",
    "param_7 = dict(reg_alpha=reg_alpha, reg_lambda=reg_lambda)\n",
    "param_7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ai/tool/bin/anaconda3/lib/python3.6/site-packages/sklearn/model_selection/_search.py:761: DeprecationWarning: The grid_scores_ attribute was deprecated in version 0.18 in favor of the more elaborate cv_results_ attribute. The grid_scores_ attribute will not be available from 0.20\n",
      "  DeprecationWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([mean: -0.58133, std: 0.00353, params: {'reg_alpha': 1.5, 'reg_lambda': 0.1},\n",
       "  mean: -0.58095, std: 0.00344, params: {'reg_alpha': 1.5, 'reg_lambda': 0.5},\n",
       "  mean: -0.58150, std: 0.00379, params: {'reg_alpha': 1.5, 'reg_lambda': 1},\n",
       "  mean: -0.58114, std: 0.00373, params: {'reg_alpha': 1.5, 'reg_lambda': 2},\n",
       "  mean: -0.58085, std: 0.00352, params: {'reg_alpha': 2, 'reg_lambda': 0.1},\n",
       "  mean: -0.58091, std: 0.00333, params: {'reg_alpha': 2, 'reg_lambda': 0.5},\n",
       "  mean: -0.58135, std: 0.00380, params: {'reg_alpha': 2, 'reg_lambda': 1},\n",
       "  mean: -0.58159, std: 0.00379, params: {'reg_alpha': 2, 'reg_lambda': 2},\n",
       "  mean: -0.58110, std: 0.00335, params: {'reg_alpha': 2.5, 'reg_lambda': 0.1},\n",
       "  mean: -0.58135, std: 0.00300, params: {'reg_alpha': 2.5, 'reg_lambda': 0.5},\n",
       "  mean: -0.58121, std: 0.00324, params: {'reg_alpha': 2.5, 'reg_lambda': 1},\n",
       "  mean: -0.58089, std: 0.00363, params: {'reg_alpha': 2.5, 'reg_lambda': 2},\n",
       "  mean: -0.58107, std: 0.00313, params: {'reg_alpha': 3, 'reg_lambda': 0.1},\n",
       "  mean: -0.58115, std: 0.00305, params: {'reg_alpha': 3, 'reg_lambda': 0.5},\n",
       "  mean: -0.58130, std: 0.00342, params: {'reg_alpha': 3, 'reg_lambda': 1},\n",
       "  mean: -0.58137, std: 0.00354, params: {'reg_alpha': 3, 'reg_lambda': 2}],\n",
       " {'reg_alpha': 2, 'reg_lambda': 0.1},\n",
       " -0.5808541031705003)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#对正则参数进行GridSearchCV验证\n",
    "xgb_7 = XGBClassifier(\n",
    "        learning_rate =0.1,\n",
    "        n_estimators=302,  \n",
    "        max_depth=5,\n",
    "        min_child_weight=4,\n",
    "        gamma=0,\n",
    "        subsample=0.95,   #将上面得到的subsample最优解代入\n",
    "        colsample_bytree=0.85,  #将上面得到的colsample_bytree最优解代入\n",
    "        colsample_bylevel = 0.7,\n",
    "        objective= 'multi:softprob',\n",
    "        nthread=-1,\n",
    "        seed=3)\n",
    "\n",
    "\n",
    "gsearch7 = GridSearchCV(xgb_7, param_grid = param_7, scoring='neg_log_loss',n_jobs=-1, cv=kfold, return_train_score=True)\n",
    "gsearch7.fit(X_train , y_train)\n",
    "\n",
    "gsearch7.grid_scores_, gsearch7.best_params_, gsearch7.best_score_\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "得到reg_alpha最优解为2,而reg_lambda最优解在设置参数范围的边界，所以要再次做微调"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'reg_lambda': [0.01, 0.1, 0.5]}"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#将reg_lambda范围进行调整\n",
    "reg_lambda = [0.01, 0.1, 0.5]      #default = 1，测试0.1， 0.5， 1，2\n",
    "\n",
    "param_8 = dict(reg_lambda=reg_lambda)\n",
    "param_8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ai/tool/bin/anaconda3/lib/python3.6/site-packages/sklearn/model_selection/_search.py:761: DeprecationWarning: The grid_scores_ attribute was deprecated in version 0.18 in favor of the more elaborate cv_results_ attribute. The grid_scores_ attribute will not be available from 0.20\n",
      "  DeprecationWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([mean: -0.58127, std: 0.00367, params: {'reg_lambda': 0.01},\n",
       "  mean: -0.58085, std: 0.00352, params: {'reg_lambda': 0.1},\n",
       "  mean: -0.58091, std: 0.00333, params: {'reg_lambda': 0.5}],\n",
       " {'reg_lambda': 0.1},\n",
       " -0.5808541031705003)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#对reg_lambda再次进行GridSearchCV验证\n",
    "xgb_8 = XGBClassifier(\n",
    "        learning_rate =0.1,\n",
    "        n_estimators=302,  \n",
    "        max_depth=5,\n",
    "        min_child_weight=4,\n",
    "        gamma=0,\n",
    "        reg_alpha=2,   #上面得到的reg_alpha最优值代入\n",
    "        subsample=0.95,\n",
    "        colsample_bytree=0.85,\n",
    "        colsample_bylevel = 0.7,\n",
    "        objective= 'multi:softprob',\n",
    "        seed=3)\n",
    "\n",
    "\n",
    "gsearch8 = GridSearchCV(xgb_8, param_grid = param_8, scoring='neg_log_loss',n_jobs=-1, cv=kfold, return_train_score=True)\n",
    "gsearch8.fit(X_train , y_train)\n",
    "\n",
    "gsearch8.grid_scores_, gsearch8.best_params_, gsearch8.best_score_\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "得到reg_lambda最优值为0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "6.调用模型进行测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "logloss of train :\n",
      "0.5041425080726454\n"
     ]
    }
   ],
   "source": [
    "#将上面调优得到的所有值代入XGBClassifier进行训练\n",
    "xgb_8 = XGBClassifier(\n",
    "        learning_rate =0.1,\n",
    "        n_estimators=302,  \n",
    "        max_depth=5,\n",
    "        min_child_weight=4,\n",
    "        gamma=0,\n",
    "        reg_alpha=2,\n",
    "        reg_lambda=0.1,   #上面得到的reg_lambda最优值代入\n",
    "        subsample=0.95,\n",
    "        colsample_bytree=0.85,\n",
    "        colsample_bylevel = 0.7,\n",
    "        objective= 'multi:softprob',\n",
    "        nthread=-1,\n",
    "        seed=3)\n",
    "#Fit the algorithm on the data\n",
    "xgb_8.fit(X_train, y_train, eval_metric='mlogloss')\n",
    "        \n",
    "#Predict training set:\n",
    "train_predprob = xgb_8.predict_proba(X_train)\n",
    "logloss = log_loss(y_train, train_predprob)\n",
    "        \n",
    "#Print model report:\n",
    "print(\"logloss of train :\" )\n",
    "print(logloss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "将reg_lambda代入后得到的logloss比之前好很多，有点不敢相信，感觉哪里错了"
   ]
  },
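One way to sanity-check that suspicion is to compare training-set and held-out logloss for the same fitted model. A sketch on synthetic data, using sklearn's GradientBoostingClassifier as a stand-in for XGBClassifier (the point is the metric, not the model): the training score is systematically optimistic.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

# Synthetic 3-class problem standing in for the rental-interest data.
X, y = make_classification(n_samples=2000, n_informative=10,
                           n_classes=3, random_state=0)
X_tr, X_va, y_tr, y_va = train_test_split(X, y, test_size=0.25,
                                          random_state=0)

clf = GradientBoostingClassifier(random_state=0).fit(X_tr, y_tr)

# Score the same model on the data it was fit to and on unseen data.
train_ll = log_loss(y_tr, clf.predict_proba(X_tr))
valid_ll = log_loss(y_va, clf.predict_proba(X_va))
print(train_ll < valid_ll)  # training logloss is the optimistic one
```

So the 0.504 above is not directly comparable to the ~0.581 cross-validated scores; a held-out split (or the CV score itself) is the fairer estimate.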
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "#对测试集进行预测\n",
    "test_predprob = xgb_8.predict_proba(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.11884218, 0.3804775 , 0.5006803 ],\n",
       "       [0.24092482, 0.43605596, 0.3230192 ],\n",
       "       [0.0484583 , 0.10842803, 0.84311366],\n",
       "       ...,\n",
       "       [0.08207287, 0.27230906, 0.6456181 ],\n",
       "       [0.43026105, 0.4878895 , 0.08184951],\n",
       "       [0.04403738, 0.39451763, 0.56144494]], dtype=float32)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#查看预测结果\n",
    "test_predprob"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "结果为三列概率数组"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "7.生成测试结果文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "#将测试出的结果保存到output_result.csv\n",
    "pd.DataFrame(test_predprob).to_csv('output_result.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "由于本次作业的运算太耗时，作业开始的又比较晚，没有时间做更多的分析研究，待后成补上\n",
    "本次作业有以下几个疑问，请老师帮忙解答一下：\n",
    "1.dtrain = xgb.DMatrix(dpath + 'RentListingInquries_FE_train.bin')得到的训练数据，如何分离出X_train？\n",
    "  一直没找到方法，最后只好把RentListingInquries_FE_train.csv导入再得到X_train\n",
    "2.dtest的行数和dtrain的行数一样，但是dtest少了label，为何不一样？\n",
    "3.XGBClassifier中有很多参数：subsample、reg_alpha、reg_lambda、colsample_bytree之类，而xgboost.cv中并没有，通过XGBClassifier.get_xgb_params()来设置xgboost.cv，这些参数能起到作用吗\n",
    "4.在重新调整弱学习器数目时，如果num_boost_round=1000,则得到的n_estimators=334,大于302,为何会这样？"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
